# Copyright 2021 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define downstream task model and loss cell."""
__all__ = ["EmotionClassifier", "LossCell", "LearningRate"]

import mindspore.common.dtype as mstype
from mindspore import nn, Tensor, Parameter
from mindspore.nn.learning_rate_schedule import LearningRateSchedule
from mindspore.ops import Cast, ScalarSummary, HistogramSummary, MatMul
from mindspore.ops import Gather
from mindspore.common.initializer import initializer


class EmotionClassifier(nn.Cell):
    """Binary emotion classifier built on a pretrained encoder.

    The backbone runs in float16 for speed; a manual linear head
    (768 -> 2) projects the pooled output, and logits are cast back to
    float32 before the softmax so the probabilities stay numerically
    stable.
    """

    def __init__(self, model):
        super(EmotionClassifier, self).__init__()
        # Run the pretrained backbone in half precision.
        self.pretrained_model = model.to_float(mstype.float16)
        # TODO: Please add classifier below this line.
        self.classifier = MatMul()
        # Linear head parameters: pooled 768-dim vector -> 2 classes.
        self.weight = Parameter(initializer("normal", [768, 2]), name="weight")
        self.bias = Parameter(initializer("zeros", [2]), name="bias")
        self.softmax = nn.Softmax(axis=1)
        self.cast = Cast()
        # TODO: Please activate the following commented line to record weights histogram.
        # self.histogram_summary = HistogramSummary()

    def construct(self, input_ids, attention_mask, token_type_ids):
        """Return class probabilities of shape (N, 2) for a batch of inputs."""
        _, pooled = self.pretrained_model(input_ids, attention_mask, token_type_ids)
        # Cast the fp32 head parameters down to match the fp16 backbone output.
        half_weight = self.cast(self.weight, mstype.float16)
        half_bias = self.cast(self.bias, mstype.float16)
        # Project: (N, 768) x (768, 2) -> (N, 2), then add the bias.
        class_logits = self.classifier(pooled, half_weight)
        class_logits += half_bias
        # TODO: Please activate the following commented line to record weights histogram.
        # self.histogram_summary("classifier_weight", self.weight)
        # self.histogram_summary("classifier_bias", self.bias)
        # Back to float32 before softmax for numerical stability.
        class_logits = self.cast(class_logits, mstype.float32)
        return self.softmax(class_logits)

# TODO: Please add loss cell with visualizing loss curve.
# class LossCell(nn.Cell):
#     def __init__(self, model, objective_fn):
#         super(LossCell, self).__init__()
#         self.model = model
#         self.objective_fn = objective_fn
#         self.scalar_collector = ScalarSummary()
#
#     def construct(self, input_ids, attention_mask, token_type_ids, labels):
#         out = self.model(input_ids, attention_mask, token_type_ids)
#         loss = self.objective_fn(out, labels)
#         self.scalar_collector('loss', loss)
#         return loss


# TODO: Please add learning rate cell with visualizing learning rate curve.
# class LearningRate(LearningRateSchedule):
#     def __init__(self, learning_rates):
#         super(LearningRate, self).__init__()
#         self.learning_rate = Tensor(learning_rates, mstype.float32)
#         self.scalar_collector = ScalarSummary()
#
#     def construct(self, global_step):
#         lr = Gather()(self.learning_rate, global_step, 0)
#         self.scalar_collector('lr', lr)
#         return lr
